"""Sup3r model software"""

import logging

import numpy as np
import tensorflow as tf

from sup3r.models.base import Sup3rGan

logger = logging.getLogger(__name__)


class SolarCC(Sup3rGan):
    """Solar climate change model.

    Note
    ----
    *Modifications to standard Sup3rGan*
        - Pointwise content loss (MAE/MSE) is only on the center 2 daylight
          hours (POINT_LOSS_HOURS) of the daily true + synthetic days and the
          temporal mean of the 24 hours of synthetic for n_days
          (usually just 1 day)
        - Discriminator only sees n_days of the center 8 daylight hours
          (DAYLIGHT_HOURS and STARTING_HOUR) of the daily true high res sample.
        - Discriminator sees random n_days of 8-hour samples (DAYLIGHT_HOURS)
          of the daily synthetic high res sample.
        - Includes padding on high resolution output of :meth:`generate` so
          that the forward pass output has low_res * t_enhance time steps
          (e.g. a multiple of 24 hours for a 24x temporal enhancement).
    """

    STARTING_HOUR = 8
    """Starting hour is the hour that daylight starts at, typically
    zero-indexed and rolled to local time"""

    DAYLIGHT_HOURS = 8
    """Daylight hours is the number of daylight hours to sample, so for example
    if STARTING_HOUR is 8 and DAYLIGHT_HOURS is 8, the daylight slice will be
    slice(8, 16). """

    POINT_LOSS_HOURS = 2
    """Number of hours from the center of the day to calculate pointwise loss
    from, e.g., MAE/MSE based on data from the true 4km hourly high res
    field."""

    def __init__(self, *args, t_enhance=None, **kwargs):
        """Add optional t_enhance adjustment.

        Parameters
        ----------
        *args : list
            List of arguments to parent class
        t_enhance : int | None
            Optional argument to fix or update the temporal enhancement of the
            model. This can be used to manipulate the output shape to match
            whatever padded shape the sup3r forward pass module expects. If
            this differs from the t_enhance value based on model layers the
            output will be padded so that the output shape matches low_res *
            t_enhance for the time dimension.
        **kwargs : Mappable
            Keyword arguments for parent class
        """
        super().__init__(*args, **kwargs)
        # A falsy t_enhance (None or 0) falls back to the value derived from
        # the model layers by the parent class.
        self._t_enhance = t_enhance or self.t_enhance
        self.meta['t_enhance'] = self._t_enhance

    def init_weights(self, lr_shape, hr_shape, device=None):
        """Initialize the generator and discriminator weights with device
        placement.

        Parameters
        ----------
        lr_shape : tuple
            Shape of one batch of low res input data for sup3r resolution. Note
            that the batch size (axis=0) must be included, but the actual batch
            size doesn't really matter.
        hr_shape : tuple | list
            Shape of one batch of high res input data for sup3r resolution.
            Note that the batch size (axis=0) must be included, but the actual
            batch size doesn't really matter. The temporal axis (axis=3) is
            replaced with DAYLIGHT_HOURS before being passed on.
        device : str | None
            Option to place model weights on a device. If None,
            self.default_device will be used.
        """
        # Normalize to a tuple so the concatenation below also works when the
        # caller passes a list.
        hr_shape = tuple(hr_shape)

        # The high resolution data passed to the discriminator should only have
        # daylight hours in the temporal axis=3
        if hr_shape[3] != self.DAYLIGHT_HOURS:
            hr_shape = hr_shape[:3] + (self.DAYLIGHT_HOURS,) + hr_shape[-1:]

        super().init_weights(lr_shape, hr_shape, device=device)

    @staticmethod
    def _daily_slices(n_days, start, length):
        """Return one ``slice(start, start + length)`` per day, where each
        successive day's slice is offset by another 24 time steps."""
        return [
            slice(start + day_start, start + day_start + length)
            for day_start in range(0, 24 * n_days, 24)
        ]

    @tf.function
    def calc_loss(
        self,
        hi_res_true,
        hi_res_gen,
        weight_gen_advers=0.001,
        train_gen=True,
        train_disc=False,
    ):
        """Calculate the GAN loss function using generated and true high
        resolution data.

        Parameters
        ----------
        hi_res_true : tf.Tensor
            Ground truth high resolution spatiotemporal data. The temporal
            axis (axis=3) must have a length that is a multiple of 24.
        hi_res_gen : tf.Tensor
            Super-resolved high resolution spatiotemporal data generated by the
            generative model.
        weight_gen_advers : float
            Weight factor for the adversarial loss component of the generator
            vs. the discriminator.
        train_gen : bool
            True if generator is being trained, then loss=loss_gen
        train_disc : bool
            True if disc is being trained, then loss=loss_disc

        Returns
        -------
        loss : tf.Tensor
            0D tensor representing the loss value for the network being trained
            (either generator or one of the discriminators). None if neither
            train_gen nor train_disc is set.
        loss_details : dict
            Namespace of the breakdown of loss components

        Raises
        ------
        RuntimeError
            If the true and synthetic high res shapes differ.
        AssertionError
            If the temporal axis length is not a multiple of 24.
        """

        if hi_res_gen.shape != hi_res_true.shape:
            msg = (
                'The tensor shapes of the synthetic output {} and '
                'true high res {} did not have matching shape! '
                'Check the spatiotemporal enhancement multipliers in your '
                'model config and data handlers.'.format(
                    hi_res_gen.shape, hi_res_true.shape
                )
            )
            logger.error(msg)
            raise RuntimeError(msg)

        t_len = hi_res_true.shape[3]
        if t_len % 24 != 0:
            msg = (
                'Special SolarCC model can only accept multi-day hourly '
                '(multiple of 24) true / synthetic high res data in the axis=3 '
                'position but received shape {}'.format(hi_res_true.shape)
            )
            logger.error(msg)
            # Raised explicitly (not via ``assert``) so the check survives
            # ``python -O``; AssertionError keeps the original exception type.
            raise AssertionError(msg)

        n_days = int(t_len // 24)

        # slices for 24-hour full days
        day_24h_slices = self._daily_slices(n_days, 0, 24)

        # slices for middle-daylight-hours
        sub_day_slices = self._daily_slices(
            n_days, self.STARTING_HOUR, self.DAYLIGHT_HOURS
        )

        # slices for middle-pointwise-loss-hours (centered on hour 12)
        point_loss_slices = self._daily_slices(
            n_days,
            (24 - self.POINT_LOSS_HOURS) // 2,
            self.POINT_LOSS_HOURS,
        )

        # sample only daylight hours for disc training and gen content loss
        disc_out_true = []
        disc_out_gen = []
        loss_gen_content = 0.0
        ziter = zip(sub_day_slices, point_loss_slices, day_24h_slices)
        for tslice_sub, tslice_ploss, tslice_24h in ziter:
            hr_true_sub = hi_res_true[:, :, :, tslice_sub, :]
            hr_gen_24h = hi_res_gen[:, :, :, tslice_24h, :]
            hr_true_ploss = hi_res_true[:, :, :, tslice_ploss, :]
            hr_gen_ploss = hi_res_gen[:, :, :, tslice_ploss, :]

            # NOTE: the true mean is taken over the daylight window only,
            # while the synthetic mean spans the full 24 hours of the day.
            hr_true_mean = tf.math.reduce_mean(hr_true_sub, axis=3)
            hr_gen_mean = tf.math.reduce_mean(hr_gen_24h, axis=3)

            gen_c_sub = self.calc_loss_gen_content(hr_true_ploss, hr_gen_ploss)
            gen_c_24h = self.calc_loss_gen_content(hr_true_mean, hr_gen_mean)
            loss_gen_content += gen_c_24h + gen_c_sub

            disc_t = self._tf_discriminate(hr_true_sub)
            disc_out_true.append(disc_t)

        # Randomly sample daylight windows from generated data. Better than
        # strided samples covering full day because the random samples will
        # provide an evenly balanced training set for the disc
        logits = [[1.0] * (t_len - self.DAYLIGHT_HOURS + 1)]
        time_samples = tf.random.categorical(logits, n_days)
        for i in range(n_days):
            t0 = time_samples[0, i]
            t1 = t0 + self.DAYLIGHT_HOURS
            disc_g = self._tf_discriminate(hi_res_gen[:, :, :, t0:t1, :])
            disc_out_gen.append(disc_g)

        # Stack the per-day discriminator outputs along a new leading axis
        # (equivalent to the previous tf.concat([list], axis=0) auto-packing).
        disc_out_true = tf.stack(disc_out_true, axis=0)
        disc_out_gen = tf.stack(disc_out_gen, axis=0)
        loss_disc = self.calc_loss_disc(disc_out_true, disc_out_gen)

        # average the content loss over the number of days
        loss_gen_content /= n_days
        loss_gen_advers = self.calc_loss_gen_advers(disc_out_gen)
        loss_gen = loss_gen_content + weight_gen_advers * loss_gen_advers

        loss = None
        if train_gen:
            loss = loss_gen
        elif train_disc:
            loss = loss_disc

        loss_details = {
            'loss_gen': loss_gen,
            'loss_gen_content': loss_gen_content,
            'loss_gen_advers': loss_gen_advers,
            'loss_disc': loss_disc,
        }

        return loss, loss_details

    def temporal_pad(self, low_res, hi_res, mode='reflect'):
        """Optionally add temporal padding to the 5D generated output array
        so that its time axis has ``low_res.shape[-2] * t_enhance`` steps.

        Parameters
        ----------
        low_res : np.ndarray
            Low-resolution input data to the spatio(temporal) GAN, which is a
            5D array of shape: (1, spatial_1, spatial_2, n_temporal,
            n_features).
        hi_res : ndarray
            Synthetically generated high-resolution data output from the
            (spatio)temporal GAN with a 5D array shape:
            (1, spatial_1, spatial_2, n_temporal, n_features)
        mode : str
            Padding mode for np.pad()

        Returns
        -------
        hi_res : ndarray
            Synthetically generated high-resolution data output from the
            (spatio)temporal GAN with a 5D array shape:
            (1, spatial_1, spatial_2, n_temporal, n_features)
            With the temporal axis padded on either side so that its length
            is low_res.shape[-2] * t_enhance. When the required padding is
            odd, the extra step goes on the trailing side.

        Raises
        ------
        ValueError
            If hi_res already has more time steps than
            low_res.shape[-2] * t_enhance.
        """
        t_shape = low_res.shape[-2] * self._t_enhance
        n_missing = t_shape - hi_res.shape[-2]
        if n_missing < 0:
            msg = (
                'SolarCC output has {} time steps which exceeds the expected '
                '{} (low res time steps {} * t_enhance {}); cannot pad.'.format(
                    hi_res.shape[-2],
                    t_shape,
                    low_res.shape[-2],
                    self._t_enhance,
                )
            )
            logger.error(msg)
            raise ValueError(msg)

        # Split the deficit so an odd difference still reaches t_shape
        # exactly; the extra step (if any) goes on the trailing side.
        pad_before = n_missing // 2
        pad_after = n_missing - pad_before
        pad_width = ((0, 0), (0, 0), (0, 0), (pad_before, pad_after), (0, 0))
        prepad_shape = hi_res.shape
        hi_res = np.pad(hi_res, pad_width, mode=mode)
        logger.debug(
            'Padded hi_res output from %s to %s', prepad_shape, hi_res.shape
        )
        return hi_res

    def generate(self, low_res, **kwargs):
        """Override parent method to apply temporal padding on the high res
        output.

        Parameters
        ----------
        low_res : np.ndarray
            Low-resolution input data, passed to the parent ``generate``.
        **kwargs : dict
            Additional keyword arguments forwarded to the parent ``generate``.

        Returns
        -------
        hi_res : ndarray
            Parent output padded along the temporal axis by
            :meth:`temporal_pad`.
        """
        hi_res = super().generate(low_res=low_res, **kwargs)
        hi_res = self.temporal_pad(low_res, hi_res)

        logger.debug('Final SolarCC output has shape: %s', hi_res.shape)

        return hi_res

    @classmethod
    def load(cls, model_dir, t_enhance=None, verbose=True):
        """Load the GAN with its sub-networks from a previously saved-to output
        directory.

        Parameters
        ----------
        model_dir : str
            Directory to load GAN model files from.
        t_enhance : int | None
            Optional argument to fix or update the temporal enhancement of the
            model. This can be used to manipulate the output shape to match
            whatever padded shape the sup3r forward pass module expects. If
            this differs from the t_enhance value based on model layers the
            output will be padded so that the output shape matches low_res *
            t_enhance for the time dimension.
        verbose : bool
            Flag to log information about the loaded model.

        Returns
        -------
        out : BaseModel
            Returns a pretrained gan model that was previously saved to out_dir
        """
        fp_gen, fp_disc, params = cls._load(model_dir, verbose=verbose)
        return cls(fp_gen, fp_disc, t_enhance=t_enhance, **params)
